{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial 7: ARGA & ARVGA  \n",
    "\n",
    "Paper:\n",
    "* [Adversarially Regularized Graph Autoencoder for Graph Embedding](https://www.ijcai.org/Proceedings/2018/0362.pdf)  \n",
    "\n",
    "Code:\n",
    " * [ARGA & ARVGA](https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/autoencoder.html)\n",
    " * [Example on clustering](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/argva_node_clustering.py)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os.path as osp\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.metrics.cluster import (v_measure_score, homogeneity_score, completeness_score)\n",
    "from sklearn.manifold import TSNE\n",
    "import matplotlib.pyplot as plt\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric.datasets import Planetoid\n",
    "from torch_geometric.nn import GCNConv\n",
    "from torch_geometric.nn.models.autoencoder import ARGVA\n",
    "from torch_geometric.utils import train_test_split_edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Opt-in switch for GPU execution: the device-selection cell further down\n",
    "# picks 'cuda' only when this flag is True and CUDA is available, else 'cpu'.\n",
    "use_cuda = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Download the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use a dedicated name for the dataset identifier instead of rebinding\n",
    "# `dataset` from a string to a Planetoid object within the same cell.\n",
    "dataset_name = 'Cora'\n",
    "path = osp.join('.', 'data', dataset_name)\n",
    "dataset = Planetoid(path, dataset_name, transform=T.NormalizeFeatures())\n",
    "\n",
    "# Index the dataset rather than calling `dataset.get(0)`: the `transform`\n",
    "# given above is applied on indexed access, whereas `get` hands back the\n",
    "# raw graph and would silently skip the feature normalization.\n",
    "data = dataset[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get the number of nodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The first dimension of the feature matrix is the node count.\n",
    "num_nodes = data.x.size(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create the train/val/test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clear the node-level masks; none of them is read again in this notebook.\n",
    "data.train_mask = None\n",
    "data.val_mask = None\n",
    "data.test_mask = None\n",
    "\n",
    "# Split the edges; the cells below read `train_pos_edge_index`,\n",
    "# `test_pos_edge_index` and `test_neg_edge_index` from the result.\n",
    "# NOTE(review): the split presumably consumes `data.edge_index`, so this\n",
    "# cell is likely not safe to re-run without reloading `data` - confirm.\n",
    "data = train_test_split_edges(data)\n",
    "data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the encoder class (the same as in `Tutorial 6`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VEncoder(torch.nn.Module):\n",
    "    \"\"\"Two-headed GCN encoder.\n",
    "\n",
    "    A shared GCN layer followed by ReLU feeds two parallel GCN heads,\n",
    "    `conv_mu` and `conv_logstd`.\n",
    "\n",
    "    Args:\n",
    "        in_channels: Size of each input node feature vector.\n",
    "        out_channels: Output size of each head; the shared hidden layer is\n",
    "            twice this wide.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels, out_channels):\n",
    "        super().__init__()\n",
    "        hidden_channels = 2 * out_channels\n",
    "        self.conv1 = GCNConv(in_channels, hidden_channels, cached=True)\n",
    "        self.conv_mu = GCNConv(hidden_channels, out_channels, cached=True)\n",
    "        self.conv_logstd = GCNConv(hidden_channels, out_channels, cached=True)\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        \"\"\"Return the outputs of the `conv_mu` and `conv_logstd` heads.\"\"\"\n",
    "        hidden = F.relu(self.conv1(x, edge_index))\n",
    "        mu = self.conv_mu(hidden, edge_index)\n",
    "        logstd = self.conv_logstd(hidden, edge_index)\n",
    "        return mu, logstd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the discriminator class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Discriminator(torch.nn.Module):\n",
    "    \"\"\"Three-layer MLP with ReLU after the first two layers.\n",
    "\n",
    "    Args:\n",
    "        in_channels: Size of each input vector.\n",
    "        hidden_channels: Width of both hidden layers.\n",
    "        out_channels: Size of each output vector.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels, hidden_channels, out_channels):\n",
    "        super().__init__()\n",
    "        self.lin1 = torch.nn.Linear(in_channels, hidden_channels)\n",
    "        self.lin2 = torch.nn.Linear(hidden_channels, hidden_channels)\n",
    "        self.lin3 = torch.nn.Linear(hidden_channels, out_channels)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the un-activated output of the final linear layer.\"\"\"\n",
    "        hidden = x\n",
    "        for layer in (self.lin1, self.lin2):\n",
    "            hidden = F.relu(layer(hidden))\n",
    "        return self.lin3(hidden)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the training algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train():\n",
    "    \"\"\"Run one training epoch and return the encoder loss.\n",
    "\n",
    "    Reads the notebook globals `model`, `discriminator`, `data`,\n",
    "    `encoder_optimizer` and `discriminator_optimizer`.\n",
    "\n",
    "    Returns:\n",
    "        The loss tensor back-propagated for the encoder step: the sum of\n",
    "        `model.reg_loss`, `model.recon_loss` and `model.kl_loss` scaled by\n",
    "        `1 / data.num_nodes`.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    discriminator.train()\n",
    "    encoder_optimizer.zero_grad()\n",
    "\n",
    "    z = model.encode(data.x, data.train_pos_edge_index)\n",
    "\n",
    "    # Five discriminator updates per encoder update. The full embedding\n",
    "    # matrix is passed directly; indexing it with every node id on each\n",
    "    # step only produced a copy with the same rows.\n",
    "    for _ in range(5):\n",
    "        discriminator_optimizer.zero_grad()\n",
    "        discriminator_loss = model.discriminator_loss(z)\n",
    "        # `z` is needed again for the encoder losses below, so any part of\n",
    "        # its graph reached by this backward pass must not be freed.\n",
    "        discriminator_loss.backward(retain_graph=True)\n",
    "        discriminator_optimizer.step()\n",
    "\n",
    "    # Encoder objective: adversarial regularization, then reconstruction of\n",
    "    # the training edges, then the KL term averaged over the nodes.\n",
    "    loss = model.reg_loss(z)\n",
    "    loss = loss + model.recon_loss(z, data.train_pos_edge_index)\n",
    "    loss = loss + (1 / data.num_nodes) * model.kl_loss()\n",
    "    loss.backward()\n",
    "\n",
    "    encoder_optimizer.step()\n",
    "\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the evaluation (test) method"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def test():\n",
    "    \"\"\"Score the current embeddings on clustering and link prediction.\n",
    "\n",
    "    Reads the notebook globals `model`, `data` and `dataset`.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(auc, ap, completeness, hm, nmi)`. `auc` and `ap` come from\n",
    "        `model.test` on the held-out test edges; the other three compare\n",
    "        the k-means assignment of the embeddings with `data.y`.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    z = model.encode(data.x, data.train_pos_edge_index)\n",
    "\n",
    "    # Cluster the embeddings with k-means, one cluster per dataset class.\n",
    "    # `n_init` is pinned because its default differs between scikit-learn\n",
    "    # releases; `fit_predict` fits and labels the points in a single call.\n",
    "    kmeans = KMeans(n_clusters=dataset.num_classes, n_init=10, random_state=0)\n",
    "    pred = kmeans.fit_predict(z.cpu().numpy())\n",
    "\n",
    "    labels = data.y.cpu().numpy()\n",
    "    completeness = completeness_score(labels, pred)\n",
    "    hm = homogeneity_score(labels, pred)\n",
    "    # Reported as NMI; computed with scikit-learn's V-measure.\n",
    "    nmi = v_measure_score(labels, pred)\n",
    "\n",
    "    auc, ap = model.test(z, data.test_pos_edge_index, data.test_neg_edge_index)\n",
    "\n",
    "    return auc, ap, completeness, hm, nmi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize an encoder and a discriminator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "latent_size = 32\n",
    "\n",
    "# The encoder maps node features to `latent_size`-dimensional outputs.\n",
    "encoder = VEncoder(in_channels=data.num_features, out_channels=latent_size)\n",
    "\n",
    "# The discriminator takes latent vectors and emits one value per vector.\n",
    "discriminator = Discriminator(\n",
    "    in_channels=latent_size,\n",
    "    hidden_channels=64,\n",
    "    out_channels=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize the model and move the model and the data to the selected device (the GPU if `use_cuda` is set and CUDA is available, otherwise the CPU)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = ARGVA(encoder, discriminator)\n",
    "\n",
    "# Use the GPU only when one is available and `use_cuda` requested it.\n",
    "if torch.cuda.is_available() and use_cuda:\n",
    "    device = torch.device('cuda')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "\n",
    "model = model.to(device)\n",
    "data = data.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the optimizers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One Adam optimizer per sub-network, each with its own learning rate;\n",
    "# `train` steps them separately.\n",
    "discriminator_optimizer = torch.optim.Adam(\n",
    "    params=discriminator.parameters(),\n",
    "    lr=0.001,\n",
    ")\n",
    "encoder_optimizer = torch.optim.Adam(\n",
    "    params=encoder.parameters(),\n",
    "    lr=0.005,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train for 200 epochs, evaluating and logging after every epoch.\n",
    "num_epochs = 200\n",
    "for epoch in range(1, num_epochs + 1):\n",
    "    loss = train()\n",
    "    auc, ap, completeness, hm, nmi = test()\n",
    "    report = (\n",
    "        f'Epoch: {epoch:03d}, Loss: {loss:.3f}, AUC: {auc:.3f}, AP: {ap:.3f}, '\n",
    "        f'Completeness: {completeness:.3f}, Homogeneity: {hm:.3f}, '\n",
    "        f'NMI: {nmi:.3f}'\n",
    "    )\n",
    "    print(report)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def plot_points(colors, random_state=0):\n",
    "    \"\"\"Scatter-plot a 2-D t-SNE projection of the node embeddings by class.\n",
    "\n",
    "    Reads the notebook globals `model`, `data` and `dataset`.\n",
    "\n",
    "    Args:\n",
    "        colors: Sequence of matplotlib colors indexed by class id; it is\n",
    "            indexed up to `dataset.num_classes - 1`.\n",
    "        random_state: Seed forwarded to t-SNE so that re-running the cell\n",
    "            reproduces the same layout. Pass `None` for an unseeded run.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    embeddings = model.encode(data.x, data.train_pos_edge_index)\n",
    "    projector = TSNE(n_components=2, random_state=random_state)\n",
    "    points = projector.fit_transform(embeddings.cpu().numpy())\n",
    "    y = data.y.cpu().numpy()\n",
    "\n",
    "    # Reuse figure 1 and clear it so repeated calls redraw in place.\n",
    "    fig = plt.figure(1, figsize=(8, 8))\n",
    "    fig.clf()\n",
    "    for class_id in range(dataset.num_classes):\n",
    "        in_class = y == class_id\n",
    "        plt.scatter(points[in_class, 0], points[in_class, 1], s=20,\n",
    "                    color=colors[class_id])\n",
    "    plt.axis('off')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One hex color per class id; `plot_points` indexes this list by class.\n",
    "colors = [\n",
    "    '#ffc0cb', '#bada55', '#008080', '#420420', '#7fe5f0', '#065535', '#ffd700'\n",
    "]\n",
    "plot_points(colors)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
